# coding: utf-8
# NAME: Clahe - Automatic Histogram enhancement
# FILE: Clahe_Rev6 (3_6)
# REVISION : 1.20.0  - 2022/08/11
# AUTHOR : Maurizio Abbate
# Copyright(c) 2021 arivis AG, Germany. All Rights Reserved.
#
# Permission is granted to use, modify and distribute this code,
# as long as this copyright notice remains part of the code.
#
# PURPOSE : Apply the Clahe algorithm to the ImageSet
#           Clahe = Contrast Limited Adaptive Histogram Equalization
#
# Tested for V4d Release : 3.6
# Uses the scikit-image (skimage) and OpenCV (cv2) libraries
#
# NOTE: the volume can optionally be rotated by 90° along X or Y before
#       applying Clahe (see ROTATION_ALONG_AXIS)
# 
# ------------------------------ External Package Import ----------------------
import time
import arivis
import arivis_objects as objects
import arivis_core as core
import numpy as np
import array
from skimage import exposure, img_as_ubyte,img_as_uint,img_as_float64
import cv2
#
# ----------------------  End of external Package Import ----------------------
#
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ USER SETTINGS @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#
# -----------------------------------------------------------------------------
INPUT_CHANNEL = 0   # <---- index of the channel to process, 0 based (Ch#1 == 0)
# prefix of the name of the new channel that stores the Clahe result
# (the channel count and the 1-based input channel number are appended)
OUTPUT_CH_NAME = "Clahe_"
# -----------------------------------------------------------------------------
# options
# -----------------------------------------------------------------------------
CLAHE_CV2 = True       #   True == Clahe using OpenCV (cv2), False == scikit-image
HISTO_STRETCH = True   #   scikit-image only: True == stretch the result to the
                       #   full 8/16 bit range, False == rescale it to the
                       #   maximum of the original plane
# -----------------------------------------------------------------------------
#   BLOCK_NUMBER_X -> number of tiles along X
#   BLOCK_NUMBER_Y -> number of tiles along Y
#   (values <= 0 are treated as 1)
# -----------------------------------------------------------------------------
BLOCK_NUMBER_X = 10
BLOCK_NUMBER_Y = 10
CLIP_LIMIT_CV2 = 20      # <-- OpenCV clipLimit (suggested range 0 - 100)
CLIP_LIMIT_SKIMAGE = 0.02   # <-- scikit-image clip_limit (range 0 - 1)
# -----------------------------------------------------------------------------
# ROTATION_ALONG_AXIS = 0 -> no rotation, the Z planes are processed
# ROTATION_ALONG_AXIS = 1 -> rotation along Y (X and Z are swapped)
# ROTATION_ALONG_AXIS = 2 -> rotation along X (Y and Z are swapped)
# -----------------------------------------------------------------------------
ROTATION_ALONG_AXIS = 0
# -----------------------------------------------------------------------------
ACTIVE_PLANE = False    #    True == process only the viewer's active plane
                        #    (ignored when ROTATION_ALONG_AXIS != 0)
ACTIVE_FRAME = True    #    True == process only the viewer's active time point
# -----------------------------------------------------------------------------
#
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ END USER SETTINGS @@@@@@@@@@@@@@@@@@@@@@@@@@@@
#
# ------------------------------ Global variables  ----------------------------
#   WARNING .......  DON'T MODIFY the following variables
# -----------------------------------------------------------------------------
# True == the temporary rotated ImageSet is deleted once its Clahe result has
# been copied back into the original ImageSet
DELETE_ROTATED_IMAGESET = True
# name of the temporary rotated ImageSet and prefix of its Clahe channel
TEMP_OUTPUT_CH_NAME = "Temporary_"
# ----------------------  End of the Global variables  ------------------------
#
# @@@@@@@@@@@@@@@@@@@@@@@@@@ V4D Utility Function @@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#
# ------------------------------ Script body ---------------------------------- 
# Function : GetEnviroment
# -----------------------------------------------------------------------------
def GetEnviroment():
  # ---------------------------------------------------------------------------
  # Return the active viewer and its ImageSet as the tuple (viewer, imageset).
  #   - (None, None)   when there is no active viewer
  #   - (viewer, None) when the viewer has no ImageSet open
  # A message is printed in both failure cases.
  # ---------------------------------------------------------------------------
  viewer = arivis.App.get_active_viewer()
  if viewer is None:
    # guard: calling get_imageset() on None would raise AttributeError
    print("No active viewer")
    return None,None
  # ---------------------------------------------------------------------------
  imageset = viewer.get_imageset()
  if imageset is None:
    print("No Image Set open")
    return viewer,None
  # ---------------------------------------------------------------------------
  return viewer,imageset
# ------------------------------ Script body ----------------------------------
# End Function : GetEnviroment
# -----------------------------------------------------------------------------
#
# ------------------------------ Script body ----------------------------------
# Function : GetStorage
# -----------------------------------------------------------------------------
def GetStorage(imageset,bObject = False):
  # ---------------------------------------------------------------------------
  # Return the DOCUMENT_STORE of `imageset`, or None (after printing the
  # reason) when the imageset, its document or the store is missing.
  #   bObject : when True, None is also returned if the store holds no
  #             objects (empty object id list)
  # ---------------------------------------------------------------------------
  if imageset is None:
    print("No Image Set open")
    return None
  # ---------------------------------------------------------------------------
  owningDoc = imageset.get_document()
  if owningDoc is None:
    print("No Document open")
    return None
  # ---------------------------------------------------------------------------
  docStore = owningDoc.get_store(imageset, objects.Store.DOCUMENT_STORE)
  if docStore is None:
    print("No Measurements availble")
    return None
  # ---------------------------------------------------------------------------
  # the full id list (empty filter) tells whether any object is stored
  # ---------------------------------------------------------------------------
  objectIds = docStore.get_object_ids("")
  storeIsEmpty = len(objectIds) == 0
  if storeIsEmpty and bObject == True:
    print("No Objects")
    return None
  # ---------------------------------------------------------------------------
  print("Get_Storage - OK ")
  return docStore
# ------------------------------ Script body ----------------------------------
# End Function : GetStorage
# -----------------------------------------------------------------------------
#
# ------------------------------ Script body ----------------------------------
# Function : GetDocument
# -----------------------------------------------------------------------------
def GetDocument(imageset):
  # ---------------------------------------------------------------------------
  # Return the document that owns `imageset`; None (with a printed message)
  # when the imageset is None or it has no document.
  # ---------------------------------------------------------------------------
  if imageset is None:
    print("No Image Set open")
    return None
  # ---------------------------------------------------------------------------
  owner = imageset.get_document()
  if owner is None:
    print("No Document open")
  return owner
# ------------------------------ Script body ----------------------------------
# End Function : GetDocument
# -----------------------------------------------------------------------------
#
# ----------------------------------------------------------------------------- 
# Function : PrepareChannel
# new channel - return channel index
# -----------------------------------------------------------------------------  
def PrepareChannel(imageset,chname):
  # ---------------------------------------------------------------------------
  # Append one new channel at the end of `imageset` and return its index
  # (= the channel count before the insertion).
  # When `chname` is not empty the new channel is named
  #     <chname><channel count after insertion>  [<INPUT_CHANNEL + 1>]
  # where the bracket shows the 1-based number of the processed channel.
  # ---------------------------------------------------------------------------
  newIndex = imageset.get_channel_count()
  print("channels " + str(newIndex))
  imageset.insert_channels(newIndex,1)
  if chname == "":
    return newIndex
  # ---------------------------------------------------------------------------
  totalChannels = imageset.get_channel_count()
  sourceTag = "  [" + str(INPUT_CHANNEL + 1) + "]"
  label = chname + str(totalChannels) + sourceTag
  imageset.set_channel_name(totalChannels - 1,label)
  return newIndex
# ----------------------------------------------------------------------------- 
# End Function : PrepareChannel
# -----------------------------------------------------------------------------
#
# ----------------------------------------------------------------------------- 
# Function : GetPixelTypes()
# -----------------------------------------------------------------------------
def GetPixelTypes(imageset):
  # -------------------------------------------------------------------------
  # Map the ImageSet pixel type to the matching NumPy / array.array types.
  # Returns the tuple (dt, typeP, Coefficent):
  #   dt         : numpy dtype of one pixel
  #   typeP      : array.array type code of one pixel
  #   Coefficent : 255 for 8 bit, 65535 for every other type
  # Pixel types other than USHORT / ULONG / FLOAT are treated as 8 bit.
  # Returns None (not a tuple) when imageset is None.
  # -------------------------------------------------------------------------
  if imageset is None:
    print("No Image Set open")
    return None
  # -------------------------------------------------------------------------
  pixelType = imageset.get_pixeltype()
  if pixelType == core.ImageSet.PIXELTYPE_USHORT:
    return np.dtype(np.uint16), 'H', int(65535)
  if pixelType == core.ImageSet.PIXELTYPE_ULONG:
    # NOTE(review): np.uint and the 'L' type code are platform dependent
    # widths, and 65535 does not match a 32 bit range - confirm if used
    return np.dtype(np.uint), 'L', int(65535)
  if pixelType == core.ImageSet.PIXELTYPE_FLOAT:
    return np.dtype(np.float32), 'f', int(65535)
  # core.ImageSet.PIXELTYPE_UCHAR (and any unknown type)
  return np.dtype(np.uint8), 'B', int(255)
# ----------------------------------------------------------------------------- 
# End Function : GetPixelTypes()
# ----------------------------------------------------------------------------- 
#
# ----------------------------------------------------------------------------- 
# Function : RotateImageset()
# -----------------------------------------------------------------------------
def RotateImageset(imageset,Axis,Input_Channel,Output_Channel=0):
    # -------------------------------------------------------------------------
    # Copy channel `Input_Channel` of `imageset` into a NEW ImageSet whose
    # volume is rotated by 90 degrees, so that the source X direction
    # (Axis == 1, "Y" rotation) or the source Y direction (Axis == 2, "X"
    # rotation) becomes the plane (Z) direction.
    #
    #   imageset       : source ImageSet
    #   Axis           : 0 = no rotation, 1 = swap X and Z, 2 = swap Y and Z
    #   Input_Channel  : source channel index (0 based)
    #   Output_Channel : channel of the new ImageSet receiving the data
    #
    # Returns:
    #   - `imageset` itself when Axis == 0 (nothing is created)
    #   - the new ImageSet (named TEMP_OUTPUT_CH_NAME, created with one
    #     channel, one inserted time point and the source pixel type) when
    #     Axis is 1 or 2
    #   - None when imageset is None, Axis is outside 0..2 or Input_Channel
    #     is >= the channel count
    # -------------------------------------------------------------------------
    if imageset is None :
        print("ImageSet not available")
        return None
    # -------------------------------------------------------------------------
    if Axis <0 or Axis>2 :
        print("Wrong axis index (0-2)")
        return None
    # -------------------------------------------------------------------------
    if Axis == 0 : return imageset    # Axis == 0 (Z)  no rotation
    # -------------------------------------------------------------------------
    channels = imageset.get_channel_count()
    if Input_Channel >= channels:
        print("Wrong input channel")
        return None
    # -------------------------------------------------------------------------
    startTime = time.time()
    # "Asse" (axis) is only the label used in the log messages
    Asse = "Y"
    if Axis == 2: Asse = "X"
    print ("[RotateImageset] Starts: (Axis=" + Asse + "  from Channel = " + str(Input_Channel) + ")")
    # -------------------------------------------------------------------------
    pixelsize = list(imageset.get_pixel_size()) # XYZ
    # -------------------------------------------------------------------------
    # create the temporary ImageSet in the same document; str(id1+1) is the
    # first argument of create_imageset
    # (NOTE(review): presumably an identifier/name - confirm with arivis API)
    # -------------------------------------------------------------------------
    BoundB = imageset.get_bounding_box()
    doc = imageset.get_document()
    listImageSet = doc.get_imagesets()
    id1 =  len(listImageSet)
    newImageSet = doc.create_imageset(str(id1+1), imageset.get_pixeltype(),channels=1)
    newImageSet.set_name(TEMP_OUTPUT_CH_NAME)
    newImageSet.insert_timepoints(0, timepoint_count =  1)
    # -------------------------------------------------------------------------
    # the newImageset is rotated
    #   Bounds2D : plane size of the new ImageSet (origin at 0, 0)
    #   Bounds3D : region of the new ImageSet written with the rotated buffer
    #   index    : buffer axis swapped with axis 0 further below
    # NOTE(review): only the time point BoundB.t1 is copied and it is written
    # at the same time point index of the new ImageSet, which has a single
    # inserted time point - assumes BoundB.t1 == 0, confirm.
    # -------------------------------------------------------------------------
    timepoint = BoundB.t1
    numberofplanes=1
    Bounds2D = core.Bounds2D()
    Bounds3D = core.Bounds3D()
    newImageSet.set_pixel_size(pixelsize[0],pixelsize[1],pixelsize[2])
    index = 1
    if Axis == 1:
        # ---------------------------------------------------------------------
        # rotation along Y: source X -> new Z (planes), source Z -> new X,
        # Y unchanged
        # ---------------------------------------------------------------------
        numberofplanes = abs(BoundB.x2 - BoundB.x1 + 1)
        Bounds2D.x = 0
        Bounds2D.y = 0
        Bounds2D.width = abs(BoundB.z2 - BoundB.z1 + 1)
        Bounds2D.height = abs(BoundB.y2 - BoundB.y1 + 1)
        index = 2
        # ---------------------------------------------------------------------
        # NOTE(review): Bounds3D reuses the source coordinates while Bounds2D
        # starts at 0 - assumes the source bounding box starts at 0, confirm
        # ---------------------------------------------------------------------
        Bounds3D.x1 = BoundB.z1
        Bounds3D.x2 = BoundB.z2
        Bounds3D.y1 = BoundB.y1
        Bounds3D.y2 = BoundB.y2
        Bounds3D.z1 = BoundB.x1
        Bounds3D.z2 = BoundB.x2
        # ---------------------------------------------------------------------
        # new pixel size (X, Y, Z) = source (Z, Y, X)
        newImageSet.set_pixel_size(pixelsize[2],pixelsize[1],pixelsize[0])
        # ---------------------------------------------------------------------
    elif Axis == 2:
        # ---------------------------------------------------------------------
        # rotation along X: source Y -> new Z (planes), source Z -> new Y,
        # X unchanged
        # ---------------------------------------------------------------------
        numberofplanes = abs(BoundB.y2 - BoundB.y1 + 1)
        Bounds2D.x = 0
        Bounds2D.y = 0
        Bounds2D.width = abs(BoundB.x2 - BoundB.x1 + 1)
        Bounds2D.height = abs(BoundB.z2 - BoundB.z1 + 1)
        index = 1
        # ---------------------------------------------------------------------
        # NOTE(review): same origin assumption as above
        # ---------------------------------------------------------------------
        Bounds3D.x1 = BoundB.x1
        Bounds3D.x2 = BoundB.x2
        Bounds3D.y1 = BoundB.z1
        Bounds3D.y2 = BoundB.z2
        Bounds3D.z1 = BoundB.y1
        Bounds3D.z2 = BoundB.y2
        # ---------------------------------------------------------------------
        # new pixel size (X, Y, Z) = source (X, Z, Y)
        newImageSet.set_pixel_size(pixelsize[0],pixelsize[2],pixelsize[1])
        # ---------------------------------------------------------------------
    # insert `numberofplanes` planes sized by Bounds2D into the new ImageSet
    newImageSet.insert_planes(0, 0, numberofplanes, Bounds2D)
    # -------------------------------------------------------------------------
    # read the whole source volume and swap axis 0 with axis `index`
    # (assumes read_imagedata returns a (z, y, x) array - TODO confirm)
    # -------------------------------------------------------------------------
    buffer = imageset.read_imagedata(BoundB, Input_Channel, timepoint)
    buffer1 = np.swapaxes(buffer, index, 0)
    #buffer1 = np.moveaxis(buffer, 0, index)
    newImageSet.write_imagedata(buffer1,Bounds3D, Output_Channel, timepoint)
    # -------------------------------------------------------------------------
    print ("[RotateImageset] time: " + str(time.time() - startTime))
    return newImageSet
# ----------------------------------------------------------------------------- 
# End Function : RotateImageset()
# ----------------------------------------------------------------------------- 
#
# ----------------------------------------------------------------------------- 
# Function : RotateImagesetBack()
# -----------------------------------------------------------------------------
def RotateImagesetBack(imagesetS,imagesetD,Axis,Input_Channel,Output_Channel=-1):
    # -------------------------------------------------------------------------
    # Copy channel `Input_Channel` of the rotated ImageSet `imagesetS` back
    # into the original ImageSet `imagesetD`, undoing the axis swap applied
    # by RotateImageset.
    #
    #   Axis           : 0 = nothing to undo, 1 = swap X/Z back,
    #                    2 = swap Y/Z back
    #   Output_Channel : destination channel of imagesetD; when negative a
    #                    new channel is appended (PrepareChannel with
    #                    OUTPUT_CH_NAME)
    #
    # When DELETE_ROTATED_IMAGESET is True the rotated ImageSet is deleted
    # and imagesetD becomes the document's default ImageSet.
    #
    # Returns imagesetD, imagesetS when Axis == 0, None on invalid input.
    # -------------------------------------------------------------------------
    if imagesetS is None or imagesetD is None:
        print("ImageSet not available")
        return None
    if Axis < 0 or Axis > 2:
        print("Wrong axis index (0-2)")
        return None
    if Axis == 0:
        # Axis == 0 (Z): nothing was rotated
        return imagesetS
    if Input_Channel >= imagesetS.get_channel_count():
        print("Wrong input channel")
        return None
    # -------------------------------------------------------------------------
    t0 = time.time()
    axisLabel = "X" if Axis == 2 else "Y"
    print("[RotateImagesetBack] Starts: (Axis=" + axisLabel + "  from Channel = " + str(Input_Channel) + ")")
    if Output_Channel < 0:
        Output_Channel = PrepareChannel(imagesetD,OUTPUT_CH_NAME)
    # -------------------------------------------------------------------------
    # read the rotated volume and swap its axes back
    # -------------------------------------------------------------------------
    destBounds = imagesetD.get_bounding_box()
    srcBounds = imagesetS.get_bounding_box()
    tp = srcBounds.t1
    rotated = imagesetS.read_imagedata(srcBounds, Input_Channel, tp)
    if Axis == 2:
        restored = np.swapaxes(rotated, 0, 1)
    elif Axis == 1:
        restored = np.swapaxes(rotated, 0, 2)
    imagesetD.write_imagedata(restored,destBounds,Output_Channel, tp)
    # -------------------------------------------------------------------------
    # optionally drop the temporary rotated ImageSet
    # -------------------------------------------------------------------------
    if DELETE_ROTATED_IMAGESET == True:
        print("[RotateImagesetBack] Rotated ImageSet is deleted:")
        owner = imagesetS.get_document()
        owner.delete_imageset(imagesetS)
        owner.set_default_imageset(imagesetD)
    print("[RotateImagesetBack] time: " + str(time.time() - t0))
    return imagesetD
# ----------------------------------------------------------------------------- 
# End Function : RotateImagesetBack()
# ----------------------------------------------------------------------------- 
#
# ----------------------------------------------------------------------------- 
# Function : Clahe
# -----------------------------------------------------------------------------  
def Clahe(imageset,viewer,Input_Channel,Output_Channel):
    # -------------------------------------------------------------------------
    # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) plane by
    # plane to `Input_Channel` of `imageset` and write the result into the
    # already existing channel `Output_Channel` of the same ImageSet.
    #
    #   imageset       : ImageSet to process (8 bit or 16 bit pixels only)
    #   viewer         : viewer providing the active plane / time point
    #                    (ACTIVE_PLANE / ACTIVE_FRAME settings)
    #   Input_Channel  : source channel index (0 based)
    #   Output_Channel : destination channel index (0 based)
    #
    # The engine is OpenCV when CLAHE_CV2 is True, scikit-image otherwise;
    # the tiling comes from BLOCK_NUMBER_X / BLOCK_NUMBER_Y.
    # Returns True when the planes were processed, False when imageset or
    # viewer are missing or the pixel type is not supported.
    # -------------------------------------------------------------------------
    if imageset is None or viewer is None:
        print("ImageSet not available")
        return False
    # -------------------------------------------------------------------------
    startTime = time.time()
    print ("[Clahe] Starts: ")
    # -------------------------------------------------------------------------
    dt,typeP,_ = GetPixelTypes(imageset)
    if dt != np.dtype(np.uint16) and dt != np.dtype(np.uint8):
        print("Image type not supported [Clahe can be applied to 16bit and 8bit images only]")
        return False
    # -------------------------------------------------------------------------
    # get the ImageSet info
    # -------------------------------------------------------------------------
    planeRect = imageset.get_bounding_rectangle()
    width = planeRect.width
    height = planeRect.height
    pixelCount = width * height
    frames = imageset.get_timepoint_count()
    planes = imageset.get_plane_count(0)        # number of planes (time point 0)
    active_plane = viewer.get_plane()
    active_frame = viewer.get_timepoint()
    # -------------------------------------------------------------------------
    # the number of tiles along X and Y must be at least 1
    # -------------------------------------------------------------------------
    block_number_x = max(1, BLOCK_NUMBER_X)
    block_number_y = max(1, BLOCK_NUMBER_Y)
    # -------------------------------------------------------------------------
    # planes to process: the viewer's plane refers to the original stack, so
    # it is only used when no rotation has been applied
    # -------------------------------------------------------------------------
    planestart, planeend = 0, planes
    if ACTIVE_PLANE and ROTATION_ALONG_AXIS == 0:
        planestart, planeend = active_plane, active_plane + 1
    # -------------------------------------------------------------------------
    # time points to process: the viewer's time point may not exist in this
    # ImageSet (the temporary rotated copy holds a single time point); in
    # that case every time point is processed instead of reading out of range
    # -------------------------------------------------------------------------
    framestart, frameend = 0, frames
    if ACTIVE_FRAME:
        if 0 <= active_frame < frames:
            framestart, frameend = active_frame, active_frame + 1
        else:
            print("Active time point not available - all time points are processed")
    # -------------------------------------------------------------------------
    # the CLAHE settings are the same for every plane: prepare them once
    # -------------------------------------------------------------------------
    if CLAHE_CV2:
        # clipLimit    -> threshold for contrast limiting
        # tileGridSize -> (tiles along X, tiles along Y)
        clahe = cv2.createCLAHE(clipLimit = CLIP_LIMIT_CV2,tileGridSize=(block_number_x,block_number_y))
        print("block number (XY) : " + str(block_number_x) + " - " + str(block_number_y))
    else:
        # a tile can not be smaller than one pixel
        block_size_x = max(1, round(width / block_number_x))
        block_size_y = max(1, round(height / block_number_y))
        # kernel_size follows the NumPy axis order of the plane, i.e.
        # (rows, cols) == (Y, X)
        kernel_size = (block_size_y, block_size_x)
        print("block size (XY) : " + str(block_size_x) + " - " + str(block_size_y))
    # -------------------------------------------------------------------------
    bound4 = imageset.get_bounding_box()
    # -------------------------------------------------------------------------
    # Loop on Time and Z
    # -------------------------------------------------------------------------
    for frame in range(framestart, frameend):
        bound4.t1 = frame
        bound4.t2 = frame
        for plane in range(planestart, planeend):
            bound4.z1 = plane
            bound4.z2 = plane
            print( "Plane :" + str(plane) + "  Time Point :" + str(frame))
            # -----------------------------------------------------------------
            # read the plane into a zero initialised buffer and turn it into
            # a (height, width) NumPy array
            # -----------------------------------------------------------------
            inputBufferOri = array.array(typeP, [0]) * pixelCount
            imageset.read_imagedata(bound4, Input_Channel, frame, buffer = inputBufferOri)
            NPBuffer2D = np.array(inputBufferOri, dtype = dt).reshape(height, width)
            # -----------------------------------------------------------------
            if CLAHE_CV2:
                NPBuffer2D__ = clahe.apply(NPBuffer2D)
            else:
                # equalize_adapthist returns a float64 image in [0, 1]
                equalized = exposure.equalize_adapthist(img_as_float64(NPBuffer2D), kernel_size = kernel_size, clip_limit=CLIP_LIMIT_SKIMAGE)
                if HISTO_STRETCH:
                    # stretch to the full range of the pixel type
                    if dt == np.dtype(np.uint16):
                        NPBuffer2D__ = img_as_uint(equalized)
                    else:
                        NPBuffer2D__ = img_as_ubyte(equalized)
                else:
                    # rescale to the maximum of the original plane
                    NPBuffer2D__ = (equalized * NPBuffer2D.max()).astype(dt)
            # -----------------------------------------------------------------
            # save the result (same pixel type as the input buffer)
            # -----------------------------------------------------------------
            outputBufferOri = array.array(typeP, np.asarray(NPBuffer2D__, dtype = dt).tobytes())
            imageset.write_imagedata(outputBufferOri, bound4, Output_Channel, frame)
    print ("[Clahe] time: " + str(time.time() - startTime))
    return True
# ----------------------------------------------------------------------------
# End Function : Clahe
# ----------------------------------------------------------------------------
#
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@  Main  @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#
# ------------------------------ Script body ---------------------------------
# MAIN BODY
# ----------------------------------------------------------------------------
# helper to get execution time
startTime = time.time()
# -----------------------------------------------------------------------------
print ("Script is running ........ " )
#
# -----------------------------------------------------------------------------
# Get the active viewer / ImageSet and, if requested, rotate the volume
# (with ROTATION_ALONG_AXIS == 0 the original ImageSet is returned as is)
# -----------------------------------------------------------------------------
viewer,imageset = GetEnviroment()
# -----------------------------------------------------------------------------
Rotate_Ch = 0
newimageset = RotateImageset(imageset,ROTATION_ALONG_AXIS,INPUT_CHANNEL,Rotate_Ch)
# -----------------------------------------------------------------------------
if newimageset is not None:
    bRotated = newimageset != imageset
    # -------------------------------------------------------------------------
    # create the channel receiving the Clahe result: a temporary channel of
    # the rotated copy, or the final channel of the original ImageSet
    # -------------------------------------------------------------------------
    if bRotated:
        Chnome = TEMP_OUTPUT_CH_NAME + str(Rotate_Ch)
    else:
        Chnome = OUTPUT_CH_NAME
        Rotate_Ch = INPUT_CHANNEL
    Output_Channel = PrepareChannel(newimageset,Chnome)
    if Clahe(newimageset,viewer,Rotate_Ch,Output_Channel):
        # copy the result back into the original ImageSet (no-op if not rotated)
        imageset = RotateImagesetBack(newimageset,imageset,ROTATION_ALONG_AXIS,Output_Channel)
        if imageset is None:
            print ("Image copy Error.....")
    else:
        # do not copy an unprocessed channel back into the original ImageSet
        print ("Clahe Error.....")
        if bRotated and DELETE_ROTATED_IMAGESET == True:
            doc = newimageset.get_document()
            doc.delete_imageset(newimageset)
            doc.set_default_imageset(imageset)
else:
    print ("Image Rotation Error....")
# ----------------------------------------------------------------------------
print ("script time: " + str(time.time() - startTime))
# ------------------------------ Script body ----------------------------------
# End of main body
# -----------------------------------------------------------------------------

